#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Module: Data utility functions. Provides various helper functions
for data processing used in other modules.
"""
__author__ = "Shubhranshu Shekhar"
__date__ = "12/8/22"

import matplotlib
import numpy as np
import pandas as pd

matplotlib.use('Agg')
import matplotlib.pyplot as plt

matplotlib.style.use('ggplot')
import seaborn as sns
from collections import Counter
from functools import reduce
from collections import OrderedDict
from scipy.spatial.distance import squareform
from scipy.spatial.distance import pdist, jaccard
from collections import defaultdict

import argparse
import logging
import os
import json
from random import randint
import pickle
import time
import bz2
import chardet

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.neighbors import NearestNeighbors, KDTree
from sklearn.linear_model import LassoCV




# defining some utility functions
def touch(file_path):
    '''
    Unix style ``touch``: create the file if it is missing and refresh its timestamps.

    The file is opened in append mode, so existing content is left untouched.

    :param file_path: path of the file to create or refresh
    '''
    handle = open(file_path, 'a')
    try:
        # None -> set both access and modification time to "now"
        os.utime(file_path, None)
    finally:
        handle.close()


def move_files(src: str, dst: str, pattern: str = '*.json'):
    '''
    Move every directory entry of ``src`` whose name matches ``pattern`` into ``dst``.

    :param src: directory to scan (not recursive)
    :param dst: destination directory; created, including parents, when missing
    :param pattern: shell-style wildcard matched against the entry names
    '''
    target_dir = pathlib.Path(dst)
    if not target_dir.is_dir():
        target_dir.mkdir(parents=True, exist_ok=True)

    # the listing is taken once, before anything is moved
    for entry in os.listdir(src):
        if fnmatch.fnmatch(entry, pattern):
            shutil.move(os.path.join(src, entry), os.path.join(dst, entry))


def mkdir_p(mypath):
    '''
    Create a directory and any missing parents, like ``mkdir -p`` on the command line.

    An already existing directory is not an error; any other OSError
    (including the path existing as a regular file) is re-raised.

    :param mypath: directory path to create
    '''
    import errno
    import os as _os

    try:
        _os.makedirs(mypath)
    except OSError as err:
        # only "already exists AND is a directory" is tolerated
        if err.errno != errno.EEXIST or not _os.path.isdir(mypath):
            raise


# Module-level setup. NOTE: everything below runs at import time, so merely
# importing this module creates '../output/' (relative to the current working
# directory) and opens a log file inside it.
OUTPUTPATH = '../output/'
mkdir_p(OUTPUTPATH)

# Configure a module logger that records DEBUG and above.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# File handler: the log is written to <OUTPUTPATH>/<this file's name>.log
handler = logging.FileHandler(OUTPUTPATH + str(os.path.basename(__file__)) + '.log')
handler.setLevel(logging.DEBUG)

# Log line layout: timestamp - logger name - level - message
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)

# Attach the file handler to the logger
logger.addHandler(handler)

# Log the pandas version to help diagnose errors that occur only on the remote server
logger.info("pandas version used: " + str(pd.__version__))


# These imports come after the logger exists so that a failure can be logged.
# move_files() (defined above) resolves shutil/pathlib/fnmatch at call time,
# so it works only if this import succeeded.
try:
    import shutil, pathlib, fnmatch
except Exception as exception:
    logger.error("MyError: Failed import shutil, pathlib, fnmatch", exc_info=True)


class NumpyEncoder(json.JSONEncoder):
    """ Custom encoder for numpy data types
        Required for storing data as json files

        Uses numpy's abstract scalar base classes instead of the aliases
        ``np.float_`` / ``np.complex_``, which were removed in NumPy 2.0.
    """

    def default(self, obj):
        """Convert numpy scalars/arrays to plain Python objects.

        :param obj: object the standard encoder could not serialise
        :return: a JSON-serialisable equivalent
        :raises TypeError: (from the base class) for unsupported types
        """
        # np.timedelta64 derives from np.signedinteger; it is deliberately not
        # treated as a plain integer here.
        if isinstance(obj, np.integer) and not isinstance(obj, np.timedelta64):
            return int(obj)

        if isinstance(obj, np.floating):
            return float(obj)

        if isinstance(obj, np.complexfloating):
            return {'real': float(obj.real), 'imag': float(obj.imag)}

        if isinstance(obj, np.ndarray):
            return obj.tolist()

        if isinstance(obj, np.bool_):
            return bool(obj)

        if isinstance(obj, np.void):
            return None

        return super().default(obj)


def jaccard_similarity(list1, list2):
    '''
    Jaccard similarity |A ∩ B| / |A ∪ B| of the distinct elements of two collections.

    Both inputs are reduced to sets first, so duplicates no longer inflate the
    union (the previous length-based formula did).

    :param list1: iterable of hashable items
    :param list2: iterable of hashable items
    :return: float in [0, 1]; 0.0 when both inputs are empty (avoids division by zero)
    '''
    set1, set2 = set(list1), set(list2)
    union = len(set1 | set2)
    if union == 0:
        return 0.0
    return float(len(set1 & set2)) / union


def create_drg_icd_dicts(df_, drg_col_, icd_cols_):
    '''
    Build, for every column in ``icd_cols_``, a nested frequency dictionary
    ``{drg value: {column value: count}}``.

    :param df_: source dataframe
    :param drg_col_: name of the column to group by
    :param icd_cols_: iterable of column names to count within each group
    :return: list of dicts, one per entry of ``icd_cols_`` and in the same order
    '''

    def _value_frequencies(series):
        return series.value_counts().to_dict()

    by_drg = df_.groupby(drg_col_)
    return [by_drg[icd_col].agg(_value_frequencies).to_dict() for icd_col in icd_cols_]


def merge_dicts(a, b):
    '''
    Merge two ``{key: {item: count}}`` dictionaries.

    Keys present in both inputs get their count dictionaries added through
    ``Counter`` addition (which keeps only positive totals); keys present in
    one input keep that input's value object unchanged.

    :param a: first dictionary of count dictionaries
    :param b: second dictionary of count dictionaries
    :return: OrderedDict sorted by key
    '''
    merged = OrderedDict()
    for key in sorted(set(a) | set(b)):
        in_a, in_b = key in a, key in b
        if in_a and in_b:
            merged[key] = dict(Counter(a[key]) + Counter(b[key]))
        else:
            merged[key] = a[key] if in_a else b[key]
    return merged


def my_hamming(a, b):
    '''Count the positions at which ``a`` and ``b`` differ (non-zero entries of ``a != b``).'''
    mismatches = a != b
    return np.count_nonzero(mismatches)


def my_hamming_s(a, b):  # sparse version
    '''Sum the element-wise inequality ``a != b``; ``np.sum`` is used so the result of the comparison is summed as-is.'''
    mismatches = a != b
    return np.sum(mismatches)


def plot_top_k_bar(df, cat_col, limit, save_path):
    '''
    Plot a bar chart of the ``limit`` most frequent values of a categorical column.

    :param df: source dataframe
    :param cat_col: name of the categorical column to count
    :param limit: number of most common categories to show
    :param save_path: file path the figure is written to
    '''
    sns.set()
    top_counts = Counter(df[cat_col].values).most_common(limit)
    fig, ax = plt.subplots(figsize=(18, 5))
    try:
        # an empty column would otherwise call ax.bar() with no arguments
        if top_counts:
            labels, freqs = zip(*top_counts)
            ax.bar(labels, freqs)
        plt.xticks(rotation=90)
        plt.xlabel(cat_col)
        plt.ylabel("Frequency")
        plt.savefig(save_path, bbox_inches='tight', dpi=50)
    finally:
        # release the figure; pyplot otherwise keeps every figure alive
        plt.close(fig)


def plot_bar_based_on_amount(df, cat_col, num_col, limit, save_path):
    '''
    Bar chart of the categories with the highest mean ``num_col``.

    The ``limit`` rows with the largest ``num_col`` are selected first; the
    mean of ``num_col`` per category is then computed on those rows only and
    the top ``limit`` categories are plotted.

    :param df: source dataframe
    :param cat_col: categorical column shown on the x axis
    :param num_col: numeric column used for ranking and bar heights
    :param limit: number of rows / categories retained
    :param save_path: output file; when None the figure is shown instead
    '''
    sns.set()
    top_rows = df.reset_index(drop=True).sort_values([num_col], ascending=False).head(limit)
    means = top_rows.groupby([cat_col])[[num_col]].mean().reset_index()
    means = means.sort_values([num_col], ascending=False).head(limit)

    fig, ax = plt.subplots(figsize=(18, 5))
    try:
        # an empty selection would otherwise call ax.bar() with no arguments
        if len(means) > 0:
            ax.bar(list(means[cat_col]), list(means[num_col]))
        plt.xticks(rotation=90)
        plt.xlabel(cat_col)
        plt.ylabel(num_col)
        if save_path is None:
            plt.show()
        else:
            plt.savefig(save_path, bbox_inches='tight', dpi=50)
    finally:
        # release the figure; pyplot otherwise keeps every figure alive
        plt.close(fig)


def boxplot_top_k_categories(df, cat_col, num_col, limit, save_path, violin=False):
    '''
    Box (or violin) plot of a numeric column for the ``limit`` most frequent categories.

    :param df: source dataframe
    :param cat_col: categorical column used for the x axis
    :param num_col: numeric column whose distribution is drawn
    :param limit: number of most common categories to keep
    :param save_path: file path the figure is written to
    :param violin: draw violin plots instead of box plots
    :return: None
    '''
    keys = [category for category, _ in Counter(df[cat_col].values).most_common(limit)]

    # keep only rows belonging to the selected categories
    df_new = df[df[cat_col].isin(keys)]

    fig, ax = plt.subplots(figsize=(18, 5))
    try:
        draw = sns.violinplot if violin else sns.boxplot
        ax = draw(x=df_new[cat_col], y=df_new[num_col], ax=ax, order=keys)
        ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
        plt.savefig(save_path, bbox_inches='tight', dpi=50)
    finally:
        # release the figure; pyplot otherwise keeps every figure alive
        plt.close(fig)


def boxplot_top_k_cats_based_on_pmt(df, cat_col, num_col, limit, save_path, violin=False, sort_col=None,
                                    sort_metric=None):
    '''
    Box (or violin) plot of ``num_col`` for the top ``limit`` categories ranked by
    an aggregate of ``sort_col``.

    :param df: source dataframe
    :param cat_col: categorical column used for the x axis
    :param num_col: numeric column whose distribution is drawn
    :param limit: number of categories to keep
    :param save_path: file path the figure is written to
    :param violin: draw violin plots instead of box plots
    :param sort_col: column used for ranking; defaults to ``num_col``
    :param sort_metric: aggregation name (e.g. 'median') used for ranking;
        the mean is used when not given
    :return: None
    '''
    if sort_col is None:
        sort_col = num_col

    grouped = df.groupby([cat_col])
    if sort_metric:
        # a list (not a set literal) of aggregations -> one column named after the metric
        ranked = grouped[sort_col].agg([sort_metric]).reset_index() \
            .sort_values([sort_metric], ascending=False)
    else:
        ranked = grouped[[sort_col]].mean().reset_index() \
            .sort_values([sort_col], ascending=False)

    keys = ranked.head(limit)[cat_col].tolist()

    # keep only rows belonging to the selected categories
    df_new = df[df[cat_col].isin(keys)]

    fig, ax = plt.subplots(figsize=(18, 5))
    try:
        draw = sns.violinplot if violin else sns.boxplot
        ax = draw(x=df_new[cat_col], y=df_new[num_col], ax=ax, order=keys)
        ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
        # NOTE(review): dpi=10 here vs dpi=50 in the sibling plot helpers -- kept as is
        plt.savefig(save_path, bbox_inches='tight', dpi=10)
    finally:
        # release the figure; pyplot otherwise keeps every figure alive
        plt.close(fig)


def pct_share_amt(df, pmt_col, drg_col, drg_grp_col):
    '''
    Percentage share (by claim count) of DRG codes within each DRG group,
    computed after dropping (group, DRG) cells with 10 or fewer rows.

    NOTE: a second function named ``pct_share_amt`` is defined further down in
    this module. That later definition rebinds the name, so this version (the
    one that applies the cell-size filter) is not the one callers get.

    :param df: claims dataframe
    :param pmt_col: column whose non-null values are counted
    :param drg_col: DRG code column
    :param drg_grp_col: DRG group column
    :return: tuple ``(dict group -> percentage share, dict group -> DRG code)``,
        both taken from the first row of each group
    '''
    # filter out rows whose (group, DRG) combination occurs fewer than 11 times
    df = df.groupby([drg_grp_col, drg_col]).filter(lambda x: len(x) > 10).reset_index(drop=True)

    # count claims per (group, DRG), convert to a percentage of the group total,
    # then keep the first row of every group
    tdf = df.groupby([drg_grp_col, drg_col])[pmt_col].agg(['count']).rename(columns={'count': pmt_col}) \
        .groupby(level=0).apply(lambda x: 100 * x / float(x.sum())) \
        .reset_index().groupby([drg_grp_col]).first().reset_index()

    # diagnostic only: the lookups below still fail if the column is missing
    if pmt_col not in tdf.columns:
        logger.info("Columns in tdf of pct_share_amt method:" + ','.join(tdf.columns))
        logger.info("df of this tdf has rows:" + str(len(df)))
    dict_grp_pct_shr = dict(zip(tdf[drg_grp_col], tdf[pmt_col]))
    dict_grp_drg_map = dict(zip(tdf[drg_grp_col], tdf[drg_col]))
    return dict_grp_pct_shr, dict_grp_drg_map


def parse_args():
    '''
    Read the command line options of the pipeline, falling back to the defaults below.

    Unknown command line arguments are ignored.
    ``<year>`` / ``<PCT>`` in the path defaults are literal placeholders.

    :return: argparse namespace with the recognised options
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('--savepath', default='../output/')

    # (flag, default) pairs; all of these are declared as string options
    string_options = (
        # filepath to an inpatient record database .dta
        ('--input', '/disk/aging/medicare/data/harm/100pct/ip/<year>/ipc<year>.dta'),
        ('--tempworkspace', '/homes/nber/shubhras-dua57882/tmp/'),
        ('--bsfcc', '/disk/aging/medicare/data/harm/<PCT>/bsfcc/<year>/bsfcc<year>.dta'),
        ('--bsf', '/disk/aging/medicare/data/harm/<PCT>/bsf/<year>/bsfab<year>.dta'),
        ('--op', '/disk/aging/medicare/data/harm/<PCT>/op/<year>/opc<year>.dta'),
        ('--car', '/disk/aging/medicare/data/harm/<PCT>/car/<year>/carc<year>.dta'),
        ('--year', '9999'),
        ('--pct', '0001pct'),
    )
    for flag, default_value in string_options:
        parser.add_argument(flag, type=str, default=default_value)

    known_args, _unknown = parser.parse_known_args()
    return known_args


def read_data(args, year, grp_drg_mapping=None):
    '''
    Load one year of inpatient claims and derive the analysis columns.

    Steps: read the Stata file in chunks (selected columns only), apply cell-size
    suppression, map DRG codes to DRG groups, merge the MS-DRG -> MDC mapping,
    compute the base claim amount and the length of stay, and finally keep only
    providers listed as category 1 / subtype 1 in the provider-of-service file.

    Side effects: creates ``<savepath>/<year>/`` and writes
    ``../output/hospitals_type1_subtype1_inpatient.csv``.

    :param args: parsed arguments (uses ``input`` and ``savepath``); when None,
        ``parse_args()`` is used
    :param year: year as a string; substituted for ``<year>`` in the input path
    :param grp_drg_mapping: optional dict DRG code -> DRG group; when falsy the
        mapping is read from ``../data/grouped_drg.json``
    :return: the prepared claims dataframe
    :raises SystemExit: with status 1 when the initial data loading fails
    '''
    # main data loading
    try:
        if args is None:
            args = parse_args()
        input_path = args.input.replace('<year>', year)

        save_path = args.savepath + year + '/'
        os.makedirs(save_path, exist_ok=True)

        # fields that are loaded in this script
        drg_col = "drg_cd"
        pmt_col = "pmt_amt"
        provider_col = "provider"  # "orgnpinm"

        useful_fields = [provider_col, "bene_id", "clm_id", "pmt_amt", "drg_cd", "at_npi",
                         "disp_shr", "ime_amt", "outlrpmt", "outlr_cd", "type_adm",
                         "prpay_cd", "pps_ind", "drgwtamt", "thru_dt", "from_dt",
                         'icd_dgns_cd1', 'icd_dgns_cd2', 'icd_dgns_cd3', 'icd_dgns_cd4', 'icd_dgns_cd5',
                         'icd_dgns_cd6', 'icd_dgns_cd7', 'icd_dgns_cd8', 'icd_dgns_cd9', 'icd_dgns_cd10',
                         "icd_prcdr_cd1", "icd_prcdr_cd2", "icd_prcdr_cd3", "icd_prcdr_cd4", "icd_prcdr_cd5",
                         "icd_prcdr_cd6", "zip_cd"
                         ]

        # load the data in chunks to avoid memory overflow, keeping only the
        # columns in useful_fields; the reader is closed when the block exits
        lst = []
        appended_dsh_op = False
        with pd.read_stata(input_path, chunksize=1000000) as itr:
            for chunk in itr:
                if (not appended_dsh_op) and 'dsh_op' in chunk.columns:
                    useful_fields.append('dsh_op')
                    appended_dsh_op = True
                lst.append(chunk[useful_fields])

        # disproportionate share column name
        # NOTE(review): 'dsh_op' is used even when the file did not contain it; in
        # that case the base-claim-amount step below fails and is only logged.
        disp_col = 'dsh_op'

        # loaded dataframe
        df = pd.concat(lst).reset_index(drop=True)
        logger.info("Data shape:" + str(df.shape))

        # Log number of unique DRG codes
        n_drg = df[drg_col].nunique()
        n_drg_nonnull = df[drg_col].count()
        n_provider = df[provider_col].nunique()
        logger.info("Unique DRG codes, %0.2f (all), %0.2f (non-null)", n_drg, n_drg_nonnull)
        logger.info("Unique providers, %0.2f ", n_provider)

        # keep rows whose (provider, bene_id) combination occurs at least 11 times
        df = cell_size_suppression(df, key_lst=[provider_col, "bene_id"],
                                   supp_key_idx=1)

        # load DRG grouping
        # this is required to create a new column to map DRG onto groups
        if grp_drg_mapping:
            drg_grp_dict = grp_drg_mapping
        else:
            with open("../data/grouped_drg.json") as grp_file:
                drg_grp_dict = json.load(grp_file)

    except Exception as exception:
        logger.error("Data loading...", exc_info=True)
        print(str(exception))
        print("Exiting program!")
        # non-zero status so that calling scripts can detect the failure
        raise SystemExit(1)

    # mapping to created DRG group -- basically contiguous DRG codes related to one diagnosis
    drg_grp_col = "drg_grp"
    try:
        df[drg_grp_col] = df[drg_col].map(drg_grp_dict)
        # log unique group counts
        n_drg_grp = df[drg_grp_col].nunique()
        n_drg_grp_nonnull = df[drg_grp_col].count()
        logger.info("Unique DRG Group codes, %0.2f (all), %0.2f (non-null)", n_drg_grp, n_drg_grp_nonnull)
    except Exception as exception:
        logger.error("MyError: Failed to map drg to DRG grouping", exc_info=True)

    # load mdc df and attach the MDC of every claim's DRG (inner join)
    mdc_col = 'MDC'
    try:
        mdc_df = pd.read_csv('../data/MS-DRG-MDC-Mapping.csv').fillna('99')
        mdc_df['MSDRG'] = mdc_df['MSDRG'].astype(float)
        df.loc[:, 'fl_CLM_DRG_CD'] = pd.to_numeric(df[drg_col], errors='coerce')
        df = pd.merge(df, mdc_df, left_on='fl_CLM_DRG_CD', right_on='MSDRG')
    except Exception as e:
        logger.error("MyError: MDC merging failed", exc_info=True)

    # mapping clm_pmt_amt to the base price
    # base_price = clm_pmt_amt - disp_shr - ime_amt - outlrpmt
    base_clm_pmt = "base_clm_amt"
    try:
        # keep rows whose disproportionate share amount is below the payment amount
        select_idx = (df[disp_col] < df[pmt_col])

        logger.info("The loaded data will retain %0.2f rows", int(select_idx.sum()))
        # reset_index() without drop keeps the previous index as an 'index' column
        df = df[select_idx].reset_index()

        # set outlier payment to 0 if the outlier code is null
        df.loc[df['outlr_cd'].isnull(), 'outlrpmt'] = 0.0
        df[base_clm_pmt] = df[pmt_col] - df[disp_col].fillna(0) - df["ime_amt"].fillna(0) - df["outlrpmt"].fillna(0)
    except Exception as exception:
        logger.error("MyError: Failed to obtain base claim amount.", exc_info=True)

    # create length-of-stay field los
    los = "los"
    try:
        df["from_dt"] = pd.to_datetime(df["from_dt"].apply(str), infer_datetime_format=True)
        df["thru_dt"] = pd.to_datetime(df["thru_dt"].apply(str), infer_datetime_format=True)

        df[los] = (df['thru_dt'] - df['from_dt']).dt.days + 1
    except Exception as exception:
        logger.error("MyError: LOS computation failed", exc_info=True)
        # second attempt on the columns as they currently are
        try:
            df["los"] = (df['thru_dt'] - df['from_dt']).dt.days + 1
        except Exception as exception:
            logger.error("MyError: Trial 2 - LOS computation failed", exc_info=True)

    # keep providers that are hospitals
    prvdr_num = df[["provider"]].drop_duplicates()
    pos_cols = ["PRVDR_CTGRY_SBTYP_CD", "PRVDR_CTGRY_CD", "PRVDR_NUM"]
    pos = pd.read_csv('../data/pos_other_sep20.csv', skipinitialspace=True, usecols=pos_cols,
                      encoding_errors='ignore')  # provider of service data

    # merge the provider num data with pos data
    prvdr_data = pd.merge(prvdr_num, pos, left_on="provider", right_on='PRVDR_NUM', how='inner')
    filtered_prvdr_data = prvdr_data[(prvdr_data['PRVDR_CTGRY_CD'] == 1) &
                                     (prvdr_data['PRVDR_CTGRY_SBTYP_CD'] == 1)]

    filtered_prvdr_data[["PRVDR_NUM"]].to_csv('../output/hospitals_type1_subtype1_inpatient.csv', index=False)

    # restrict data to hospitals only
    valid_providers_df = filtered_prvdr_data[['PRVDR_NUM']].copy()

    # inner join keeps only rows whose provider exists in valid_providers_df
    df = df.merge(valid_providers_df, left_on='provider', right_on='PRVDR_NUM', how='inner')
    df = df.drop('PRVDR_NUM', axis=1).reset_index(drop=True)
    logger.info("After filtering to hospitals only, data shape:" + str(df.shape))

    return df


def read_chronic_conditions_data(bsfcc_path, year, pct='100pct'):
    '''
    Load the chronic condition flags of one year from a Stata file.

    :param bsfcc_path: path template containing ``<year>`` and ``<PCT>`` placeholders
    :param year: year as a string, substituted for ``<year>``
    :param pct: sample identifier substituted for ``<PCT>``
    :return: tuple ``(dataframe, bsfcc_fields)``; ``bsfcc_fields`` is the list of
        loaded columns with ``bene_id`` first, followed by the condition flags
    :raises SystemExit: with status 1 when loading fails
    '''
    try:
        bsfcc_path = bsfcc_path.replace('<year>', year).replace('<PCT>', pct)

        bsfcc_fields = ["bene_id", "ami", "alzh", "alzhdmta", "atrialfb", "cataract", "chrnkidn",
                        "copd", "chf", "diabetes", "glaucoma", "hipfrac", "ischmcht", "depressn",
                        "osteoprs", "ra_oa", "strketia", "cncrbrst", "cncrclrc", "cncrprst",
                        "cncrlung", "cncrendm", "anemia", "asthma", "hyperl", "hyperp", "hypert",
                        "hypoth"]
        chronic_fields = bsfcc_fields[1:]

        # read the file in chunks, keeping only the needed columns; the reader
        # is closed when the block exits
        lst = []
        with pd.read_stata(bsfcc_path, chunksize=1000000) as itr:
            for chunk in itr:
                lst.append(chunk[bsfcc_fields])

        # loaded dataframe
        df = pd.concat(lst).reset_index(drop=True)
        logger.info("BSFCC Data shape:" + str(df.shape))

        # translate the flags: {0: 0, 1: 1, 2: 0, 3: 1}
        df[chronic_fields] = df[chronic_fields] % 2  # 0 --> absence of disease, 1 --> presence

    except Exception as exception:
        logger.error("BSFCC Data loading...", exc_info=True)
        print(str(exception))
        print("Exiting program!")
        # non-zero status so that calling scripts can detect the failure
        raise SystemExit(1)

    return df, bsfcc_fields


def chronic_cdns_df(cdf, bdf, ch_fields):
    '''
    Count, per provider, how many claims belong to beneficiaries flagged with
    each chronic condition.

    :param cdf: claims dataframe (needs ``provider`` and ``bene_id``)
    :param bdf: chronic conditions dataframe (needs ``bene_id`` and the condition columns)
    :param ch_fields: column names; the first entry is skipped (the id column),
        the remaining entries are the chronic condition columns
    :return: dataframe indexed by provider with one summed column per condition,
        or None when the computation fails (the error is logged)
    '''
    try:
        condition_cols = list(ch_fields[1:])

        # Step 1: merge on bene_id using only the columns that are needed; this
        # keeps the merge small and avoids _x/_y suffixes when the claims frame
        # happens to contain a column named like a condition.
        merged = cdf[['provider', 'bene_id']].merge(bdf[['bene_id'] + condition_cols], on='bene_id')

        # Step 2: group by provider and sum the 0/1 condition flags
        return merged[['provider'] + condition_cols].groupby(['provider']).sum()
    except Exception:
        logger.error("MyError: failed to build provider chronic conditions dataframe", exc_info=True)
        return None


def chronic_cdns_peers(cdf, bdf, ch_fields):
    '''
    Find peer providers from the chronic condition profile of their claims.

    Each provider is represented by its row-normalised vector of chronic
    condition counts; peers are the nearest neighbours (at most 20, including
    the provider itself) whose distance is at most 0.2.

    :param cdf: claims dataframe (needs ``provider`` and ``bene_id``)
    :param bdf: chronic conditions dataframe
    :param ch_fields: column names; the first entry is skipped, the rest are
        the chronic condition columns
    :return: tuple ``(peers dict provider -> Index of peer providers,
        provider index, condition columns, normalised array)``
    :raises Exception: any failure is logged and re-raised (previously the
        function fell through to a return of unbound variables)
    '''
    try:
        # Step 1: Merge the two df based on bene_id
        df = cdf.merge(bdf, left_on='bene_id', right_on='bene_id').reset_index(drop=True)

        # Step 2: Select provider col and only the chronic conditions
        select_fields = ['provider'] + ch_fields[1:]
        df = df[select_fields]

        # Step 3: Group by provider and sum the 0/1 condition flags
        prvdr_chronic = df.groupby(['provider']).sum()
        prvdr_codes = prvdr_chronic.index  # provider identifiers, in array row order
        chronic_columns = prvdr_chronic.columns

        # normalizing the data -- every row sums to one
        prvdr_chr_arr = prvdr_chronic.values
        prvdr_chr_arr = prvdr_chr_arr / prvdr_chr_arr.sum(axis=1, keepdims=True)

        # kneighbors() cannot return more neighbours than there are providers
        n_neighbors = min(20, len(prvdr_codes))
        neigh = NearestNeighbors(n_neighbors=n_neighbors, radius=0.2)
        nbrs = neigh.fit(prvdr_chr_arr)
        distances, indices = nbrs.kneighbors(prvdr_chr_arr)

        nn_selected = distances <= 0.2
        # [1:] drops the first selected neighbour (the provider itself at distance 0)
        chr_peers_dict = {
            code: prvdr_codes[indices[i][nn_selected[i]][1:]]
            for i, code in enumerate(prvdr_codes)
        }
    except Exception:
        logger.error("MyError: failed with kNN", exc_info=True)
        raise

    return chr_peers_dict, prvdr_codes, chronic_columns, prvdr_chr_arr


def get_count_dct(df, count_key, count_col):
    '''
    Value counts of ``count_col`` within every group of ``count_key``.

    :param df: source dataframe
    :param count_key: name of the column to group by
    :param count_col: name of the column whose values are counted
    :return: dict ``{group value: {column value: count}}``; groups that fail
        are logged and skipped
    '''
    counts_dct = {}
    # group on the plain column name: grouping on a one-element list makes
    # pandas >= 2 yield 1-tuples as group keys
    for nm, grp_df in df.groupby(count_key):
        try:
            counts_dct[nm] = grp_df[count_col].value_counts().to_dict()
        except Exception:
            logger.warning("MyError: get_count_dct failed for key: " + str(nm), exc_info=True)
    return counts_dct


# functions to create features
def get_DRG_distribution(df_, prvdr_, mdc_, drg_weight_dct, mdc_drg_dct, provider_col="provider"):
    '''
    Share of a provider's claims per DRG, with DRGs ordered by decreasing weight.

    :param df_: claims dataframe containing ``provider_col`` and ``MSDRG``
    :param prvdr_: provider whose claims are used
    :param mdc_: key into ``mdc_drg_dct`` restricting the DRGs; None uses every
        DRG in ``drg_weight_dct``
    :param drg_weight_dct: dict DRG -> weight
    :param mdc_drg_dct: dict MDC -> list of DRGs
    :param provider_col: name of the provider column
    :return: tuple ``(list of shares, list of DRGs)`` in the same order; all
        shares are 0 when the provider has no matching claims
    '''
    candidate_drgs = list(drg_weight_dct.keys()) if mdc_ is None else mdc_drg_dct[mdc_]

    # order DRGs by (weight, DRG) descending
    weighted = sorted(((drg_weight_dct[drg], drg) for drg in candidate_drgs), reverse=True)
    sorted_drg_lst_ = [drg for _, drg in weighted]

    is_provider = df_[provider_col] == prvdr_
    is_candidate = df_['MSDRG'].isin(sorted_drg_lst_)
    counts_dct = dict(df_[is_provider & is_candidate]['MSDRG'].value_counts())

    g_cnts = [counts_dct.get(drg, 0.) for drg in sorted_drg_lst_]
    g_sum = sum(g_cnts)

    if g_sum > 0:
        shares = [cnt / g_sum for cnt in g_cnts]
    else:
        shares = [0 for _ in g_cnts]
    return shares, sorted_drg_lst_

# convenience function to plot cluster centers
def plot_cluster_centers(c_center, xticks=None, ylim=None, save_path=None):
    '''
    Convenience function to draw a cluster centre as a bar chart.

    :param c_center: sequence of bar heights
    :param xticks: optional sequence of tick labels (list or array); every
        2nd label is shown, every 4th when there are more than 150 bars
    :param ylim: optional upper limit of the y axis
    :param save_path: output file; when falsy the figure is shown instead
    '''
    wide = len(c_center) > 150
    fig = plt.figure(figsize=(16, 1) if wide else (8, 1))
    jump = 4 if wide else 2
    try:
        plt.bar(np.arange(len(c_center)), c_center)

        # explicit length test: truth-testing a numpy array raises ValueError
        if xticks is not None and len(xticks) > 0:
            plt.xticks(range(0, len(xticks), jump), xticks[::jump], rotation=90)
        else:
            plt.xticks(range(0, len(c_center), jump))

        if ylim:
            plt.ylim(0, ylim)

        if save_path:
            # a bare file name has no directory part; os.makedirs('') would fail
            out_dir = os.path.dirname(save_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            plt.savefig(save_path, bbox_inches='tight', dpi=50)
        else:
            plt.show()
    finally:
        # release the figure; pyplot otherwise keeps every figure alive
        plt.close(fig)


def get_features(tuple_dicts, peers):
    '''
    Build a provider x feature matrix of deviations from peer averages.

    Every dictionary in ``tuple_dicts`` is turned into one feature column with
    ``get_feat_for_a_dict``. Only providers that received a value for every
    feature are kept (a provider skipped for any feature previously caused a
    KeyError); their order follows the first feature dictionary.

    :param tuple_dicts: sequence of dicts provider -> value
    :param peers: dict provider -> collection of peer providers
    :return: tuple ``(array of shape (n providers, n features), list of providers)``;
        an empty array and an empty list when ``tuple_dicts`` is empty
    '''
    lst_single_feat_dcts = [get_feat_for_a_dict(dct, peers) for dct in tuple_dicts]
    if not lst_single_feat_dcts:
        return np.array([]), []

    first, rest = lst_single_feat_dcts[0], lst_single_feat_dcts[1:]
    prvdrs = [pr for pr in first if all(pr in dct for dct in rest)]
    feats = [[dct[pr] for dct in lst_single_feat_dcts] for pr in prvdrs]

    return np.array(feats), prvdrs


def get_feat_for_a_dict(dct_, peers):
    '''
    Deviation of every provider's value from the mean value of its peers.

    A provider is skipped when it has no peers entry, fewer than 5 peers, or
    fewer than 3 peers with a value in ``dct_``.

    :param dct_: dict provider -> value
    :param peers: dict provider -> collection of peer providers
    :return: dict-like provider -> ``sum(value - mean of peer values)``
    '''
    deviations = defaultdict()
    for provider, value in dct_.items():
        peer_list = peers[provider] if provider in peers else None
        if peer_list is None or len(peer_list) < 5:
            continue

        peer_values = np.array([dct_[peer] for peer in peer_list if peer in dct_])
        if peer_values.shape[0] >= 3:
            peer_mean = np.mean(peer_values)
            deviations[provider] = np.sum(value - peer_mean)
    return deviations


def get_avgWt(claims_df, drg_wt_dct, grp_col_index, grp_col, drg_col="drg_cd", prvdr_col="provider"):
    '''
    Average DRG weight of every provider within every group (e.g. MDC).

    :param claims_df: claims dataframe
    :param drg_wt_dct: dict DRG code -> weight
    :param grp_col_index: dict group value -> position in the output vector
    :param grp_col: name of the grouping column (e.g. 'MDC')
    :param drg_col: name of the DRG code column
    :param prvdr_col: name of the provider column
    :return: dict-like provider -> vector of length ``len(grp_col_index)``;
        groups without claims stay 0
    '''
    provider_avg_wt = defaultdict()
    # group on the plain column name: grouping on a one-element list makes
    # pandas >= 2 yield 1-tuples as keys, which would not match provider codes
    for pr_, grp_df in claims_df.groupby(prvdr_col):
        avg_wt_vec = np.zeros(len(grp_col_index))
        for mdc_, m_grp_df in grp_df.groupby(grp_col):
            uniq, cnts = np.unique(m_grp_df[drg_col], return_counts=True)
            # claim-count weighted mean of the DRG weights
            wt_sum = sum(drg_wt_dct[drg_] * cnt_ for drg_, cnt_ in zip(uniq, cnts))
            avg_wt_vec[grp_col_index[mdc_]] = wt_sum / sum(cnts)
        provider_avg_wt[pr_] = avg_wt_vec
    return provider_avg_wt


def get_relative_weight(claims_df, prvdr_avg_wts, grp_col_index, grp_col, prvdr_col="provider"):
    '''
    Divide every provider's average weight per group by the provider's mean
    ``amLOS`` in that group.

    :param claims_df: claims dataframe; must contain an ``amLOS`` column
    :param prvdr_avg_wts: dict provider -> average weight vector (not modified)
    :param grp_col_index: dict group value -> position in the vector
    :param grp_col: name of the grouping column (e.g. 'MDC')
    :param prvdr_col: name of the provider column
    :return: dict-like provider -> vector of relative weights
    '''
    provider_rw = defaultdict()
    # group on the plain column name: grouping on a one-element list makes
    # pandas >= 2 yield 1-tuples as keys, which would miss in prvdr_avg_wts
    for pr_, grp_df in claims_df.groupby(prvdr_col):
        rel_wt_vec = prvdr_avg_wts[pr_].copy()
        for mdc_, m_grp_df in grp_df.groupby(grp_col):
            idx = grp_col_index[mdc_]
            alos = sum(m_grp_df['amLOS']) / len(m_grp_df)
            rel_wt_vec[idx] = rel_wt_vec[idx] / alos
        provider_rw[pr_] = rel_wt_vec
    return provider_rw


def pct_share_amt(df, pmt_col, drg_col, drg_grp_col):
    '''
    Percentage share (by claim count) of the first DRG code of every DRG group.

    Claims with a non-null ``pmt_col`` are counted per (group, DRG); each count
    is expressed as a percentage of its group total and the first row of every
    group (groups and DRGs in sorted order) is returned.

    The share is computed with ``transform`` rather than ``groupby.apply``:
    on pandas >= 2 the ``apply`` form prepends the group key to the index
    again, which makes the following ``reset_index`` fail.

    :param df: claims dataframe
    :param pmt_col: column whose non-null values are counted
    :param drg_col: DRG code column
    :param drg_grp_col: DRG group column
    :return: tuple ``(dict group -> percentage share, dict group -> DRG code)``
    '''
    counts = df.groupby([drg_grp_col, drg_col])[pmt_col].count()
    group_totals = counts.groupby(level=0).transform('sum')
    shares = 100 * counts / group_totals

    tdf = shares.reset_index(name=pmt_col).groupby([drg_grp_col]).first().reset_index()

    dict_grp_pct_shr = dict(zip(tdf[drg_grp_col], tdf[pmt_col]))
    dict_grp_drg_map = dict(zip(tdf[drg_grp_col], tdf[drg_col]))
    return dict_grp_pct_shr, dict_grp_drg_map


def get_pct_shr_prvdr(df, provider_col, pmt_col, drg_col, drg_grp_col):
    '''
    Per-provider percentage shares as computed by ``pct_share_amt``.

    :param df: claims dataframe
    :param provider_col: name of the provider column
    :param pmt_col: column whose non-null values are counted
    :param drg_col: DRG code column
    :param drg_grp_col: DRG group column
    :return: dict provider -> dict group -> percentage share; providers for
        which the computation fails are reported and skipped
    '''
    pct_shr_dct = {}
    # group on the plain column name: grouping on a one-element list makes
    # pandas >= 2 yield 1-tuples as keys instead of provider codes
    for nm, grp_df in df.groupby(provider_col):
        try:
            pct_shr_, _grp_drg = pct_share_amt(grp_df, pmt_col, drg_col, drg_grp_col)
        except Exception:
            print("MyInfo: pct_share_amt failed for provider: " + str(nm))
            # keep the reason for the failure instead of discarding it
            logger.warning("pct_share_amt failed for provider: " + str(nm), exc_info=True)
            continue
        pct_shr_dct[nm] = pct_shr_
    return pct_shr_dct


def convert_pct_shr_dct_to_vec(prvdr_pct_share, drg_grp_idx):
    '''
    Turn every provider's ``{group: share}`` dictionary into a dense vector.

    :param prvdr_pct_share: dict provider -> dict group -> share
    :param drg_grp_idx: dict group -> position in the vector
    :return: dict-like provider -> numpy vector of length ``len(drg_grp_idx)``;
        groups a provider does not have stay 0
    '''
    n_groups = len(drg_grp_idx)
    vectors = defaultdict()
    for provider, shares in prvdr_pct_share.items():
        dense = np.zeros(n_groups)
        for group, share in shares.items():
            dense[drg_grp_idx[group]] = share
        vectors[provider] = dense
    return vectors


def cell_size_suppression(df_, key_lst=None, supp_key_idx=1, min_count=11):
    '''
    Keep only the rows whose ``key_lst`` combination occurs at least ``min_count`` times.

    The input frame is not modified; boolean filtering already produces a new
    frame, so no up-front copy of the (potentially large) input is made.

    NOTE(review): rows are counted per combination of all keys in ``key_lst``
    (e.g. per (provider, bene_id) pair), not per distinct value of one key
    within the other -- confirm this matches the intended suppression rule.

    :param df_: source dataframe
    :param key_lst: columns defining a cell; defaults to ["provider", "bene_id"]
    :param supp_key_idx: position in ``key_lst`` of the column used to carry
        the group size
    :param min_count: minimum number of rows a cell needs to be kept
    :return: filtered dataframe with a fresh 0..n-1 index
    '''
    if key_lst is None:
        key_lst = ["provider", "bene_id"]
    count_col = key_lst[supp_key_idx]
    cell_sizes = df_.groupby(key_lst, sort=False)[count_col].transform('size')
    return df_[cell_sizes >= min_count].reset_index(drop=True)


def get_peers(args, df, prvdr_mdc_counts_dct, provider_col, drg_col, peer_type='MDC'):
    '''
    Find peer providers with a nearest-neighbour search.

    Depending on ``peer_type`` the peers come from (a) the normalised MDC count
    profile of each provider ('MDC'), (b) a pickled dictionary read from
    '../data/icd9_peers.pkl' ('ICD9'), or (c) for any other value, the MDC
    profile joined with the provider's chronic condition profile.

    NOTE(review): each stage is wrapped in its own try/except that only logs.
    Names assigned inside a failed stage (e.g. ``mdc_index``, ``prvdr_mdc_df``,
    ``mdc_peers_dict``, ``prvdr_codes``) stay unbound, so a later stage or the
    final return can fail with a NameError/UnboundLocalError instead.

    :param args: parsed arguments; ``args.year`` and ``args.bsfcc`` are read on
        the combined path
    :param df: claims dataframe
    :param prvdr_mdc_counts_dct: dict provider -> dict MDC -> count
    :param provider_col: name of the provider column in ``df``
    :param drg_col: name of the DRG column in ``df``
    :param peer_type: 'MDC', BFCC, Combined, ICD9
    :return: tuple ``(peers dict, dict MDC column -> position, provider index)``
    '''
    # kNN on the MDC representation of every provider
    try:
        # providers as rows, mdc codes as columns (missing combinations -> 0)
        prvdr_mdc_df = pd.DataFrame.from_dict(prvdr_mdc_counts_dct, orient='index').fillna(0)
        prvdr_mdc_df.columns = prvdr_mdc_df.columns.astype(str)
        prvdr_mdc_df = prvdr_mdc_df.reindex(sorted(prvdr_mdc_df.columns), axis=1)
        # normalize every row by its row sum
        prvdr_mdc_df = prvdr_mdc_df.div(prvdr_mdc_df.sum(axis=1), axis=0)

        # prvdr_codes = prvdr_mdc_df.index  # store the provider names in this variable
        mdc_columns = prvdr_mdc_df.columns
        mdc_index = dict(zip(list(mdc_columns), range(len(mdc_columns))))

        if peer_type == 'MDC':
            # rows were already normalised above; this divides by the row sum a second time
            prvdr_mdc_arr = prvdr_mdc_df.values
            prvdr_mdc_arr = prvdr_mdc_arr / prvdr_mdc_arr.sum(axis=1, keepdims=True)

            # create peers dict based on MDC representation
            # (kneighbors() is used, so the radius argument does not drive the query below)
            neigh = NearestNeighbors(n_neighbors=50, radius=0.4)
            nbrs = neigh.fit(prvdr_mdc_arr)
            distances, indices = nbrs.kneighbors(prvdr_mdc_arr)

            # keep neighbours within distance 0.2 (original note: "choose 80% similarity")
            nn_selected = np.where(distances <= 0.2, True, False) # choose 80% similarity
            prvdr_codes = prvdr_mdc_df.index
            mdc_peers_dict = {}
            for i in range(len(prvdr_codes)):
                # [1:] drops the first selected neighbour -- presumably the provider itself
                mdc_peers_dict[prvdr_codes[i]] = prvdr_codes[indices[i][nn_selected[i]][1:]]

            return mdc_peers_dict, mdc_index, prvdr_codes
    except Exception as e:
        logger.error("MyError: failed with kNN", exc_info=True)

    if peer_type == 'ICD9':
        # icd9 peers are precomputed and loaded from disk
        # NOTE(review): pickle.load executes arbitrary code from the file; only use a trusted file
        with open('../data/icd9_peers.pkl', 'rb') as f:  # values
            icd9_peers_dict = pickle.load(f)
        return icd9_peers_dict, mdc_index, prvdr_mdc_df.index

    # find peers based on mbsfcc (chronic conditions) data
    try:
        counts_prvdr_drg = df.groupby([provider_col, drg_col], sort=False)[drg_col].transform('size')
        cdf = df[counts_prvdr_drg >= 11].reset_index(drop=True)  # at least 11 claims should be there

        # reading chronic conditions file of the previous year
        yr = str(int(args.year) - 1)
        bdf, bsfcc_cols = read_chronic_conditions_data(args.bsfcc, yr)  # '2007'

        # get chronic conditions dataframe
        prvdr_chronic_cdns_df = chronic_cdns_df(cdf, bdf, bsfcc_cols)
        # normalize every row by its row sum
        prvdr_chronic_cdns_df = prvdr_chronic_cdns_df.div(prvdr_chronic_cdns_df.sum(axis=1), axis=0)

    except Exception as e:
        logger.error("MyError: In bsfcc based peers creation", exc_info=True)

    # find combined neighbors based on MDC and chronic conditions
    try:
        # 1. combine the two dataframes on the provider index
        combined_df = prvdr_mdc_df.merge(prvdr_chronic_cdns_df, left_index=True, right_index=True)
        mdc_columns = combined_df.columns
        prvdr_codes = combined_df.index  # store the provider names in this variable

        # Find peers based on the combined representation
        prvdr_mdc_arr = combined_df.values
        neigh = NearestNeighbors(n_neighbors=50, radius=0.4)
        nbrs = neigh.fit(prvdr_mdc_arr)
        distances, indices = nbrs.kneighbors(prvdr_mdc_arr)

        nn_selected = np.where(distances <= 0.2, True, False)
        mdc_peers_dict = {}
        for i in range(len(prvdr_codes)):
            mdc_peers_dict[prvdr_codes[i]] = prvdr_codes[indices[i][nn_selected[i]][1:]]

    except Exception as e:
        logger.error("MyError: In combined provider representation", exc_info=True)

    # convert the peers to plain lists "for later reference"
    # NOTE(review): save_mdc_peers is built but never written or returned in this function
    try:
        save_mdc_peers = {}
        for k, v in mdc_peers_dict.items():
            save_mdc_peers[k] = v.tolist()
    except Exception as e:
        logger.error("MyError: saving mdc based peers", exc_info=True)
    # mdc_index still describes the MDC-only columns, not the combined columns
    return mdc_peers_dict, mdc_index, prvdr_codes


def _save_icd_code_distribution(df, code_col, counts_label, provider_col, drg_col, save_path):
    '''
    Save the per-provider and per-DRG count distributions of one ICD-9 code column.

    Two files are written into ``save_path``:
    ``provider_<code_col>.json`` (provider -> {code: count}) and
    ``drg_<code_col>_mapping.json`` (DRG -> {code: count}, keeping only counts above 10).
    Failures are logged and never propagated; if the cell-size suppression step fails the
    column is skipped so that no stale data from another column is written.

    :param df: claims dataframe holding ``provider_col``, ``drg_col`` and ``code_col``
    :param code_col: name of the ICD-9 diagnosis or procedure column to summarise
    :param counts_label: label used in the per-provider failure log message
    :param provider_col: name of the provider column
    :param drg_col: name of the DRG column
    :param save_path: output directory including its trailing separator
    '''
    # cell size suppression: keep only (provider, code) pairs with at least 11 claims
    try:
        pair_sizes = df.groupby([provider_col, code_col], sort=False)[code_col].transform('size')
        cdf = df[pair_sizes >= 11].reset_index(drop=True)
    except Exception:
        logger.error("MyError: Could not do cell-size compression wrt ICD9", exc_info=True)
        return

    # code counts for each provider
    try:
        provider_counts = {}
        # group by the bare column name so that keys are scalars (JSON-serialisable)
        # regardless of the pandas version
        for nm, grp_df in cdf.groupby(provider_col):
            try:
                provider_counts[nm] = grp_df[code_col].value_counts().to_dict()
            except Exception:
                logger.info("MyInfo: failed " + counts_label + ": " + str(nm))
                continue

        with open(save_path + "provider_" + str(code_col) + ".json", 'w') as fh:
            json.dump(provider_counts, fh, cls=NumpyEncoder)
    except Exception:
        logger.error("MyError: Failed computing provider icd9 distribution", exc_info=True)

    # code counts for each DRG
    try:
        drg_code_mapping = cdf.groupby(drg_col)[code_col].agg(
            lambda x: x.value_counts().to_dict()).to_dict()
        logger.info("DRG to ICD src mapping done.")

        # filter out codes that occur less than 11 times
        drg_code_mapping = {drg: {code: cnt for code, cnt in code_counts.items() if cnt > 10}
                            for drg, code_counts in drg_code_mapping.items()}
        logger.info("DRG to ICD filtered rarely used codes.")

        with open(save_path + "drg_" + str(code_col) + "_mapping.json", 'w') as fh:
            json.dump(drg_code_mapping, fh)
    except Exception:
        logger.error("MyError: Failed while drg icd mapping. ", exc_info=True)


def save_provider_icd9mapping(args, year, grp_drg_mapping=None):
    '''
    Save, for every ICD-9 diagnosis and procedure column present in the claims data of
    ``year``, the per-provider code counts and the DRG -> code counts as JSON files under
    ``args.savepath + year + '/'``.

    Providers with fewer than 11 claims are dropped first; each code column is then
    suppressed independently to (provider, code) pairs with at least 11 claims.

    :param args: parsed command line arguments; parsed via ``parse_args()`` when None
    :param year: year as a string, used for reading the data and naming the output directory
    :param grp_drg_mapping: forwarded unchanged to ``read_data``
    '''
    # args must be resolved before its attributes are read
    if args is None:
        args = parse_args()

    save_path = args.savepath + year + '/'
    os.makedirs(save_path, exist_ok=True)

    df = read_data(args=args, year=year, grp_drg_mapping=grp_drg_mapping)

    # book keeping
    provider_col = "provider"
    drg_col = "drg_cd"

    icd9_cols = ['icd_dgns_cd1', 'icd_dgns_cd2', 'icd_dgns_cd3', 'icd_dgns_cd4', 'icd_dgns_cd5',
                 'icd_dgns_cd6', 'icd_dgns_cd7', 'icd_dgns_cd8', 'icd_dgns_cd9', 'icd_dgns_cd10']
    icd9_prcdrs = ['icd_prcdr_cd1', 'icd_prcdr_cd2', 'icd_prcdr_cd3', 'icd_prcdr_cd4', 'icd_prcdr_cd5',
                   'icd_prcdr_cd6']

    # only the columns that are actually present in this year's data
    icd9_cols = [c for c in icd9_cols if c in df.columns]
    icd9_prcdrs = [c for c in icd9_prcdrs if c in df.columns]

    # drop hospitals and rows if a hospital has less than 11 rows
    try:
        counts_prvdr = df.groupby(provider_col, sort=False)[provider_col].transform('size')
        df = df[counts_prvdr >= 11].reset_index(drop=True)  # at least 11 claims should be there
    except Exception:
        logger.error("MyError: Could not do cell-size compression", exc_info=True)

    # diagnosis columns
    for icd_dgns in icd9_cols:
        _save_icd_code_distribution(df, icd_dgns, "prvdr_icd_dgns_counts",
                                    provider_col, drg_col, save_path)

    # procedure columns: processed once each, independently of the diagnosis columns
    for icd_prcdr in icd9_prcdrs:
        _save_icd_code_distribution(df, icd_prcdr, "prvdr_icd_prcdr_counts",
                                    provider_col, drg_col, save_path)


def compute_stats_beneficiary(args, year, grp_drg_mapping=None):
    '''
    Save the mean, standard deviation and claim count of the base claim amount for every
    chronic condition, computed over the claims of ``year`` joined with the chronic
    conditions data of the previous year.

    Writes ``chr_mean_dct.json``, ``chr_std_dct.json`` and ``chr_cnt_dct.json`` into
    ``args.savepath + year + '/'``. Failures are logged and never propagated; when the
    chronic conditions data cannot be read or merged nothing is written.

    :param args: parsed command line arguments; parsed via ``parse_args()`` when None
    :param year: year of the claims data, as a string or an int
    :param grp_drg_mapping: forwarded unchanged to ``read_data``
    '''
    drg_col = "drg_cd"
    provider_col = "provider"
    base_clm_pmt = "base_clm_amt"

    if args is None:
        args = parse_args()

    # callers pass the year as a string; normalise so that both path building and the
    # previous-year arithmetic below work
    year = str(year)

    # make output directory for the given year
    save_path = args.savepath + year + '/'
    os.makedirs(save_path, exist_ok=True)

    df = read_data(args=args, year=year, grp_drg_mapping=grp_drg_mapping)

    # drop hospitals and rows if a hospital has less than 11 rows
    try:
        counts_prvdr = df.groupby(provider_col, sort=False)[provider_col].transform('size')
        df = df[counts_prvdr >= 11].reset_index(drop=True)  # at least 11 claims should be there
    except Exception:
        logger.error("MyError: Could not do cell-size compression", exc_info=True)

    # drop rows if (hospital, drg) has less than 11 rows: cell size suppression
    try:
        counts_prvdr_drg = df.groupby([provider_col, drg_col], sort=False)[drg_col].transform('size')
        df = df[counts_prvdr_drg >= 11].reset_index(drop=True)  # at least 11 claims should be there
    except Exception:
        logger.error("MyError: Could not do cell-size compression wrt DRG", exc_info=True)

    # reading chronic conditions file of the previous year
    try:
        bdf, bsfcc_cols = read_chronic_conditions_data(args.bsfcc, str(int(year) - 1))
    except Exception:
        logger.error("MyError: In beneficiary main - Could not read chronic conditions data",
                     exc_info=True)
        return  # nothing can be computed without the chronic conditions columns

    # merge the claims with the chronic conditions on bene_id
    try:
        df = df.merge(bdf, left_on='bene_id', right_on='bene_id').reset_index(drop=True)
    except Exception:
        logger.error("MyError: In beneficiary main - Could not merge claims with chronic conditions data",
                     exc_info=True)
        return

    # save distribution of price for each chronic condition
    try:
        chr_mean_dct = {}
        chr_std_dct = {}
        chr_cnt_dct = {}
        for col in bsfcc_cols:
            # base amounts of the claims whose beneficiary is flagged with this condition
            amounts = df.loc[df[col] == 1, base_clm_pmt]
            chr_mean_dct[col] = amounts.mean()
            chr_std_dct[col] = amounts.std()
            chr_cnt_dct[col] = len(amounts)

        for file_name, stats_dct in (("chr_mean_dct.json", chr_mean_dct),
                                     ("chr_std_dct.json", chr_std_dct),
                                     ("chr_cnt_dct.json", chr_cnt_dct)):
            with open(save_path + file_name, 'w') as fh:
                json.dump(stats_dct, fh)
    except Exception:
        logger.error("MyError: Computing mean std of price per chronic disease", exc_info=True)


def get_provider_mdc_counts_dict(df, provider_col, mdc_mapping_path='../data/MS-DRG-MDC-Mapping.csv'):
    '''
    Count, for each provider, how many claims fall into each MDC.

    The DRG -> MDC mapping is read from ``mdc_mapping_path`` (columns ``MSDRG`` and ``MDC``,
    missing values replaced with '99'). As a side effect an ``MDC`` column derived from
    ``df["drg_cd"]`` is added to ``df`` in place.

    :param df: claims dataframe containing ``provider_col`` and ``drg_cd``
    :param provider_col: name of the provider column
    :param mdc_mapping_path: path of the DRG to MDC mapping csv file
    :return: dict provider -> {MDC: count}; empty (or partial) if the computation failed,
             in which case the error is logged
    '''
    mdc_col = 'MDC'
    # initialised up front so that a failure below still returns a dict instead of
    # raising UnboundLocalError at the return statement
    prvdr_mdc_counts_dct = {}

    try:
        mdc_df = pd.read_csv(mdc_mapping_path, dtype=str).fillna('99')
        drg_mdc_dct = dict(zip(mdc_df['MSDRG'], mdc_df['MDC']))
        df[mdc_col] = df["drg_cd"].map(drg_mdc_dct)

        for nm, grp_df in df.groupby([provider_col]):
            # newer pandas yields 1-tuples when grouping by a list of one column
            if isinstance(nm, tuple):
                nm = nm[0]
            try:
                # mdc counts
                prvdr_mdc_counts_dct[nm] = grp_df[mdc_col].value_counts().to_dict()
            except Exception:
                logger.info("MyInfo: MDC counts failed for provider: " + str(nm))
                continue
    except Exception:
        logger.error("MyError: MDC merging failed", exc_info=True)

    return prvdr_mdc_counts_dct


def get_provider_chronic_counts_dict(df, args):
    '''
    Build the provider level chronic conditions dataframe from the claims in ``df`` and the
    chronic conditions data of the year before ``args.year``.

    :param df: claims dataframe, passed on to ``chronic_cdns_df``
    :param args: arguments providing ``year``, ``bsfcc`` and ``pct``
    :return: the dataframe returned by ``chronic_cdns_df``, or None if reading or building
             failed (the error is logged)
    '''
    # initialised up front so that a failure below returns None instead of raising
    # UnboundLocalError at the return statement
    prvdr_chronic_cdns_df = None
    try:
        yr = str(int(args.year) - 1)  # previous year's diagnosed chronic condition
        bdf, bsfcc_cols = read_chronic_conditions_data(args.bsfcc, yr, args.pct)
        # get chronic conditions dataframe
        prvdr_chronic_cdns_df = chronic_cdns_df(df, bdf, bsfcc_cols)
        print("Successfully created chronic cdns df!")
    except Exception:
        logger.error("MyError: Chronic conditions counts creation failed", exc_info=True)
    return prvdr_chronic_cdns_df


def _dump_json_to_path(obj, file_path, **dump_kwargs):
    '''Serialise ``obj`` as JSON into ``file_path`` and close the file handle afterwards.'''
    with open(file_path, 'w') as fh:
        json.dump(obj, fh, **dump_kwargs)


def _save_peers(args, df, prvdr_mdc_counts_dct, provider_col, drg_col, peer_type, out_file, error_msg):
    '''Compute peers of the given ``peer_type`` via ``get_peers`` and save them as JSON lists.'''
    peers_dict, _, _ = get_peers(args, df, prvdr_mdc_counts_dct, provider_col, drg_col,
                                 peer_type=peer_type)
    try:
        _dump_json_to_path({k: v.tolist() for k, v in peers_dict.items()}, out_file)
    except Exception:
        logger.error(error_msg, exc_info=True)


def _process_year(args, yr, pct):
    '''
    Read the claims of one year and save all derived statistics under ``args.savepath + yr + '/'``.

    :param args: parsed arguments; ``args.year`` and ``args.pct`` are overwritten here
    :param yr: year to process, as a string
    :param pct: sample identifier stored in ``args.pct`` before reading the data
    '''
    save_path = args.savepath + yr + '/'
    mkdir_p(save_path)

    # read data per year
    args.year = yr
    args.pct = pct

    print("Reading data...")
    df = read_data(args=args, year=yr)

    # restrict things to ER cases
    # df = df[df['type_adm'] == '1'].reset_index(drop=True)

    drg_col = "drg_cd"
    provider_col = "provider"

    # save data for creating tensor -- provider X src X beneficiary
    # provider --> beneficiary --> src
    dgns_cols = ['icd_dgns_cd1', 'icd_dgns_cd2', 'icd_dgns_cd3', 'icd_dgns_cd4', 'icd_dgns_cd5',
                 'icd_dgns_cd6', 'icd_dgns_cd7', 'icd_dgns_cd8', 'icd_dgns_cd9', 'icd_dgns_cd10']
    prcdr_cols = ['icd_prcdr_cd1', 'icd_prcdr_cd2', 'icd_prcdr_cd3', 'icd_prcdr_cd4',
                  'icd_prcdr_cd5', 'icd_prcdr_cd6']
    prvdr_beneficiary_icd9dgns_dct = defaultdict(lambda: defaultdict(list))
    prvdr_beneficiary_icd9prcdr_dct = defaultdict(lambda: defaultdict(list))
    for _, row in df.iterrows():
        prvdr, bene = row['provider'], row['bene_id']
        prvdr_beneficiary_icd9dgns_dct[prvdr][bene].extend([row[c] for c in dgns_cols])
        prvdr_beneficiary_icd9prcdr_dct[prvdr][bene].extend([row[c] for c in prcdr_cols])

    _dump_json_to_path(prvdr_beneficiary_icd9dgns_dct,
                       save_path + "prvdr_beneficiary_icd9dgns_dct.json", cls=NumpyEncoder)
    _dump_json_to_path(prvdr_beneficiary_icd9prcdr_dct,
                       save_path + "prvdr_beneficiary_icd9prcdr_dct.json", cls=NumpyEncoder)

    # save primary icd9 and prcdr distribution per provider
    # (grouping by the bare column name keeps the keys scalar, hence JSON-serialisable,
    # regardless of the pandas version)
    prvdr_icd9_dgns_counts = {}
    prvdr_icd9_prcdr_counts = {}
    for nm, grp_df in df.groupby(provider_col):
        try:
            # diagnosis src
            prvdr_icd9_dgns_counts[nm] = grp_df["icd_dgns_cd1"].value_counts().to_dict()
            # procedure src
            prvdr_icd9_prcdr_counts[nm] = grp_df["icd_prcdr_cd1"].value_counts().to_dict()
        except Exception:
            logger.info("MyInfo: pct_share_amt failed for provider: " + str(nm))
            continue
    _dump_json_to_path(prvdr_icd9_dgns_counts, save_path + "provider_icd9_dgns.json", cls=NumpyEncoder)
    _dump_json_to_path(prvdr_icd9_prcdr_counts, save_path + "provider_icd9_prcdr.json", cls=NumpyEncoder)

    # save median prices for each year
    try:
        drg_median_dct = df.groupby(drg_col)["base_clm_amt"].median().to_dict()
        _dump_json_to_path(drg_median_dct, save_path + "DRG_median_base.json", cls=NumpyEncoder)
    except Exception:
        logger.error("MyError: Failed saving DRG median prices for year {}".format(yr), exc_info=True)

    # save DRG distribution per provider
    prvdr_drg_counts = {}
    for nm, grp_df in df.groupby(provider_col):
        try:
            prvdr_drg_counts[nm] = grp_df[drg_col].value_counts().to_dict()
        except Exception:
            logger.info("MyInfo: pct_share_amt failed for provider: " + str(nm))
            continue
    _dump_json_to_path(prvdr_drg_counts, save_path + "provider_drg_counts.json", cls=NumpyEncoder)

    # save base price and total amount information (list of amounts per provider)
    df_base_dct = df.groupby(provider_col)["base_clm_amt"].apply(list).to_dict()
    df_amount_dct = df.groupby(provider_col)["pmt_amt"].apply(list).to_dict()
    _dump_json_to_path(df_base_dct, save_path + "prvdr_base_price_list.json")
    _dump_json_to_path(df_amount_dct, save_path + "prvdr_amount_list.json")

    # save average base price and amount per (provider, drg)
    df_prvdr_DRG_prices_dct = (df.groupby([provider_col, drg_col])
                               .agg({"base_clm_amt": 'mean', "pmt_amt": 'mean'})
                               .rename(columns={"base_clm_amt": 'base_clm_amt_avg', "pmt_amt": "pmt_amt_avg"})
                               .to_dict(orient="index")
                               )
    # JSON keys must be strings: flatten the (provider, drg) tuple into "provider:drg"
    df_prvdr_DRG_prices_dct = dict((':'.join(k), v) for k, v in df_prvdr_DRG_prices_dct.items())
    _dump_json_to_path(df_prvdr_DRG_prices_dct, save_path + "prvdr_DRG_prices_dct.json")

    # get provider mdc counts
    prvdr_mdc_counts_dct = get_provider_mdc_counts_dict(df, provider_col)
    _dump_json_to_path(prvdr_mdc_counts_dct, save_path + "prvdr_mdc_counts_dct.json")

    # get provider chronic conditions counts
    prvdr_chronic_counts_df = get_provider_chronic_counts_dict(df, args)
    with bz2.BZ2File(save_path + "prvdr_chronic_counts_df.bz2.pkl", 'wb') as f:
        pickle.dump(prvdr_chronic_counts_df, f)

    # peers based on MDC
    _save_peers(args, df, prvdr_mdc_counts_dct, provider_col, drg_col, 'MDC',
                save_path + "mdc_peers.json", "MyError: saving mdc based peers")

    # peers based on MDC and chronic conditions
    _save_peers(args, df, prvdr_mdc_counts_dct, provider_col, drg_col, 'combined',
                save_path + "mdc_chronic_peers.json",
                "MyError: saving mdc + chronic conditions based peers")

    # save primary icd code to drg mapping
    # diagnosis codes
    icd_dgns_drg_mapping = df.groupby('icd_dgns_cd1')['drg_cd'].agg(
        lambda x: x.value_counts().to_dict()).to_dict()
    _dump_json_to_path(icd_dgns_drg_mapping, save_path + "icd1_dgns_drg_mapping.json")

    # procedure codes
    icd_prcdr_drg_mapping = df.groupby('icd_prcdr_cd1')['drg_cd'].agg(
        lambda x: x.value_counts().to_dict()).to_dict()
    _dump_json_to_path(icd_prcdr_drg_mapping, save_path + "icd1_prcdr_drg_mapping.json")

    # get provider to icd1 to drg mapping
    save_prvdr_icd_drg(df, save_path)

    # compute beneficiary stats
    compute_stats_beneficiary(args, yr)


def main():
    '''
    Entry point: for each configured year read the claims data and save the derived
    statistics (code distributions, prices, peers, mappings) under ``OUTPUTPATH``.

    ``Completed.log`` is removed at the start and re-created at the end to mark a finished
    run. Any exception raised while processing is logged and stops the remaining years.
    '''
    # Delete the file Completed.log if it exists to mark the start of running the program
    try:
        os.remove('Completed.log')
    except FileNotFoundError:
        pass  # first run or already cleaned up: nothing to report
    except OSError as e:
        logger.error("MyError: %s - %s", e.filename, e.strerror)

    PCT = '100pct'  #'0001pct' #
    try:
        # parse argument
        args = parse_args()
        args.savepath = OUTPUTPATH

        # run the main function to record various statistics
        for yr in ['2017', '2016', '2015', '2014', '2013', '2012']:
            start = time.time()  # to see how much time each year requires
            _process_year(args, yr, PCT)
            end = time.time()
            print('Elapsed time for year {} is {} seconds.'.format(yr, end - start))
            logger.info('Elapsed time for year {} is {} seconds.'.format(yr, end - start))
    except Exception:
        logger.error("MyError: In Main", exc_info=True)

    # recreate completed.log to mark the end of program run
    touch('Completed.log')
    print("Processed each block!")


def save_prvdr_icd_drg(df, save_path):
    '''
    Save the nested mapping provider -> primary diagnosis code -> DRG -> number of claims
    as ``prvdr_icd1_dgns_drg_mapping.json`` inside ``save_path``.

    :param df: claims dataframe containing the columns ``provider``, ``icd_dgns_cd1`` and ``drg_cd``
    :param save_path: output directory including its trailing separator
    '''
    grouped = df.groupby(['provider', 'icd_dgns_cd1', 'drg_cd']).size()

    nested_dict = {}
    for (prvdr, icd_code, drg), count in grouped.items():
        # the provider and code levels are created even when the count is not stored
        drg_counts = nested_dict.setdefault(prvdr, {}).setdefault(icd_code, {})
        if count > 0:
            drg_counts[drg] = count

    # save dict; the context manager guarantees the file is flushed and closed
    with open(save_path + "prvdr_icd1_dgns_drg_mapping.json", 'w') as fh:
        json.dump(nested_dict, fh)
    print('Saved!')


def save_ccn(out_file):
    '''
    Collect the distinct provider records (CCN, name, city, state, zip) from the yearly
    provider/service csv files under ``../data/`` and write them to ``out_file``.

    :param out_file: path of the csv file to write
    '''
    files = ['MUP_IHP_R19_P08_V10_D14_Prov_Svc.csv', 'MUP_IHP_R19_P08_V10_D15_Prov_Svc.csv',
             'MUP_IHP_R19_P08_V10_D16_Prov_Svc.csv', 'MUP_IHP_R19_P08_V10_D17_Prov_Svc.csv',
             'MUP_IHP_RY21_P02_V10_DY18_PrvSvc.csv', 'MUP_IHP_RY21_P02_V10_DY19_PrvSvc_0.csv',
             'MUP_INP_RY25_P03_V10_DY23_PrvSvc.csv']
    base_path = '../data/'

    # source column -> output column, kept as two parallel lists in the output order
    src_cols = ['Rndrng_Prvdr_CCN', 'Rndrng_Prvdr_Org_Name', 'Rndrng_Prvdr_City',
                'Rndrng_Prvdr_State_Abrvtn', 'Rndrng_Prvdr_Zip5']
    out_cols = ['CCN', 'name', 'city', 'state', 'zip']

    df_lst = []
    for f in files:
        file_path = base_path + f
        # the files do not share one encoding, so detect it per file
        with open(file_path, 'rb') as fen:
            encoding = chardet.detect(fen.read())['encoding']
        df_ = pd.read_csv(file_path, usecols=src_cols, encoding=encoding)
        # read_csv returns the columns in file order, not in usecols order; select them
        # explicitly so that the positional renaming below can never mislabel a column
        df_lst.append(df_[src_cols])

    df = pd.concat(df_lst)
    print('Rows in original df ', len(df))
    df.columns = out_cols
    # NOTE(review): CCN and zip are read with inferred dtypes; if they are parsed as
    # integers, leading zeros are lost -- confirm whether they should be read as str
    df = df.drop_duplicates().reset_index(drop=True)
    print('# of hospitals in USA ', len(df))
    df.to_csv(out_file, index=False)


def load_ccn(infile):
    '''
    Read the csv file at ``infile`` with the pandas default settings.

    :param infile: path of the csv file to read
    :return: the file content as a dataframe
    '''
    return pd.read_csv(infile)

# Run the yearly statistics pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()

